/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.zookeeper.server;

import static org.jboss.netty.buffer.ChannelBuffers.wrappedBuffer;

import java.io.BufferedWriter;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.Writer;
import java.lang.management.ManagementFactory;
import java.lang.management.OperatingSystemMXBean;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.util.AbstractSet;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.jute.BinaryInputArchive;
import org.apache.jute.BinaryOutputArchive;
import org.apache.jute.Record;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.zookeeper.Environment;
import org.apache.zookeeper.Version;
import org.apache.zookeeper.WatchedEvent;
import org.apache.zookeeper.proto.ReplyHeader;
import org.apache.zookeeper.proto.WatcherEvent;
import org.apache.zookeeper.server.quorum.Leader;
import org.apache.zookeeper.server.quorum.LeaderZooKeeperServer;
import org.apache.zookeeper.server.quorum.ReadOnlyZooKeeperServer;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.MessageEvent;

import com.sun.management.UnixOperatingSystemMXBean;

public class NettyServerCnxn extends ServerCnxn {
	/**
	 * Four letter command handler that resets the statistics of every connection
	 * currently tracked by the factory.
	 */
	private class CnxnStatResetCommand extends CommandThread {
		public CnxnStatResetCommand(PrintWriter pw) {
			super(pw);
		}

		@Override
		public void commandRun() {
			if (zkServer == null) {
				pw.println(ZK_NOT_SERVING);
				return;
			}
			// Iterate under the cnxns lock, the same lock close() takes before
			// removing a connection from the set.
			synchronized (factory.cnxns) {
				for (ServerCnxn cnxn : factory.cnxns) {
					cnxn.resetStats();
				}
			}
			pw.println("Connection stats reset.");
		}
	}

	/**
	 * Set of threads for commmand ports. All the 4 letter commands are run via a
	 * thread. Each class maps to a correspoding 4 letter command. CommandThread is
	 * the abstract class from which all the others inherit.
	 */
	private abstract class CommandThread /* extends Thread */ {
		PrintWriter pw;

		CommandThread(PrintWriter pw) {
			this.pw = pw;
		}

		public abstract void commandRun() throws IOException;

		public void run() {
			try {
				commandRun();
			} catch (IOException ie) {
				LOG.error("Error in running command ", ie);
			} finally {
				cleanupWriterSocket(pw);
			}
		}

		public void start() {
			run();
		}
	}

	/**
	 * Four letter command handler that asks the server to dump its configuration
	 * to the client.
	 */
	private class ConfCommand extends CommandThread {
		ConfCommand(PrintWriter pw) {
			super(pw);
		}

		@Override
		public void commandRun() {
			if (zkServer == null) {
				pw.println(ZK_NOT_SERVING);
				return;
			}
			zkServer.dumpConf(pw);
		}
	}

	/**
	 * Four letter command handler that prints connection information for every
	 * connection tracked by the factory.
	 */
	private class ConsCommand extends CommandThread {
		public ConsCommand(PrintWriter pw) {
			super(pw);
		}

		@Override
		public void commandRun() {
			if (zkServer == null) {
				pw.println(ZK_NOT_SERVING);
				return;
			}
			// Copy the set while holding the lock and print from the copy, so the
			// cnxns lock is released before any output is produced.
			final Set<ServerCnxn> snapshot;
			synchronized (factory.cnxns) {
				snapshot = new HashSet<ServerCnxn>(factory.cnxns);
			}
			for (ServerCnxn cnxn : snapshot) {
				cnxn.dumpConnectionInfo(pw, false);
				pw.println();
			}
			pw.println();
		}
	}

	/**
	 * Four letter command handler that prints the session tracker's sessions
	 * followed by the server's ephemeral nodes.
	 */
	private class DumpCommand extends CommandThread {
		public DumpCommand(PrintWriter pw) {
			super(pw);
		}

		@Override
		public void commandRun() {
			if (zkServer == null) {
				pw.println(ZK_NOT_SERVING);
				return;
			}
			pw.println("SessionTracker dump:");
			zkServer.sessionTracker.dumpSessions(pw);

			pw.println("ephemeral nodes dump:");
			zkServer.dumpEphemerals(pw);
		}
	}

	/**
	 * Four letter command handler that prints the entries returned by
	 * {@code Environment.list()} as {@code key=value} lines. It does not depend on
	 * a ZooKeeperServer being attached.
	 */
	private class EnvCommand extends CommandThread {
		EnvCommand(PrintWriter pw) {
			super(pw);
		}

		@Override
		public void commandRun() {
			final List<Environment.Entry> entries = Environment.list();

			pw.println("Environment:");
			for (Environment.Entry entry : entries) {
				pw.print(entry.getKey());
				pw.print("=");
				pw.println(entry.getValue());
			}
		}
	}

	/**
	 * Four letter command handler that reports the server mode: {@code null} when
	 * no server is attached, {@code ro} for a ReadOnlyZooKeeperServer, {@code rw}
	 * otherwise. The reply carries no trailing newline.
	 */
	private class IsroCommand extends CommandThread {

		public IsroCommand(PrintWriter pw) {
			super(pw);
		}

		@Override
		public void commandRun() {
			final String mode;
			if (zkServer == null) {
				mode = "null";
			} else if (zkServer instanceof ReadOnlyZooKeeperServer) {
				mode = "ro";
			} else {
				mode = "rw";
			}
			pw.print(mode);
		}
	}

	/**
	 * Four letter command handler that prints server metrics, one per line, as
	 * {@code zk_<key>}, a tab, and the value.
	 */
	private class MonitorCommand extends CommandThread {

		MonitorCommand(PrintWriter pw) {
			super(pw);
		}

		@Override
		public void commandRun() {
			if (zkServer == null) {
				pw.println(ZK_NOT_SERVING);
				return;
			}
			final ZKDatabase db = zkServer.getZKDatabase();
			final ServerStats serverStats = zkServer.serverStats();

			print("version", Version.getFullVersion());

			// Request latency figures.
			print("avg_latency", serverStats.getAvgLatency());
			print("max_latency", serverStats.getMaxLatency());
			print("min_latency", serverStats.getMinLatency());

			// Traffic counters.
			print("packets_received", serverStats.getPacketsReceived());
			print("packets_sent", serverStats.getPacketsSent());

			print("outstanding_requests", serverStats.getOutstandingRequests());

			print("server_state", serverStats.getServerState());
			print("znode_count", db.getNodeCount());

			// Data tree figures.
			print("watch_count", db.getDataTree().getWatchCount());
			print("ephemerals_count", db.getDataTree().getEphemeralsCount());
			print("approximate_data_size", db.getDataTree().approximateDataSize());

			// File descriptor counts are only exposed by the Unix flavour of the OS bean;
			// instanceof is false for null, so no separate null check is needed.
			final OperatingSystemMXBean osBean = ManagementFactory.getOperatingSystemMXBean();
			if (osBean instanceof UnixOperatingSystemMXBean) {
				final UnixOperatingSystemMXBean unixBean = (UnixOperatingSystemMXBean) osBean;

				print("open_file_descriptor_count", unixBean.getOpenFileDescriptorCount());
				print("max_file_descriptor_count", unixBean.getMaxFileDescriptorCount());
			}

			// Quorum figures are only printed when the server reports itself as leader.
			if (serverStats.getServerState().equals("leader")) {
				final Leader leader = ((LeaderZooKeeperServer) zkServer).getLeader();

				print("followers", leader.learners.size());
				print("synced_followers", leader.forwardingFollowers.size());
				print("pending_syncs", leader.pendingSyncs.size());
			}
		}

		/** Prints a numeric metric. */
		private void print(String key, long number) {
			print(key, Long.toString(number));
		}

		/** Prints one metric line: {@code zk_<key>}, a tab, then the value. */
		private void print(String key, String value) {
			pw.print("zk_");
			pw.print(key);
			pw.print("\t");
			pw.println(value);
		}

	}

	/**
	 * Message event that carries only a channel; its future, message and remote
	 * address are all null. {@code enableRecv()} sends an instance upstream
	 * through the channel pipeline when a throttled connection is resumed.
	 */
	static class ResumeMessageEvent implements MessageEvent {
		Channel channel;

		ResumeMessageEvent(Channel channel) {
			this.channel = channel;
		}

		/** @return the channel this event was created for */
		public Channel getChannel() {
			return this.channel;
		}

		/** @return always null; the event has no associated future */
		public ChannelFuture getFuture() {
			return null;
		}

		/** @return always null; the event carries no payload */
		public Object getMessage() {
			return null;
		}

		/** @return always null */
		public SocketAddress getRemoteAddress() {
			return null;
		}
	}

	/**
	 * Four letter command handler that replies with {@code imok} (no trailing
	 * newline), regardless of whether a ZooKeeperServer is attached.
	 */
	private class RuokCommand extends CommandThread {
		public RuokCommand(PrintWriter pw) {
			super(pw);
		}

		@Override
		public void commandRun() {
			pw.print("imok");
		}
	}

	/**
	 * Writer that forwards its output to the enclosing connection's
	 * {@code sendBuffer} in chunks. Characters are collected in an internal buffer
	 * and sent once more than 2048 have accumulated, or whenever the writer is
	 * flushed or closed, so a large command response is never built up completely
	 * in memory.
	 */
	private class SendBufferWriter extends Writer {
		private StringBuffer sb = new StringBuffer();

		/**
		 * Sends the buffered characters if a chunk is due.
		 *
		 * @param force send whatever is buffered (if anything), even when it is less
		 *              than a full chunk
		 */
		private void checkFlush(boolean force) {
			final int pending = sb.length();
			// Forced: send any non-empty buffer. Otherwise: only once past the chunk size.
			final boolean due = force ? pending > 0 : pending > 2048;
			if (!due) {
				return;
			}
			sendBuffer(ByteBuffer.wrap(sb.toString().getBytes()));
			// Reset the internal buffer for the next chunk.
			sb.setLength(0);
		}

		@Override
		public void close() throws IOException {
			if (sb != null) {
				checkFlush(true);
				// Drop the buffer so the writer cannot be reused after close.
				sb = null;
			}
		}

		@Override
		public void flush() throws IOException {
			checkFlush(true);
		}

		@Override
		public void write(char[] cbuf, int off, int len) throws IOException {
			sb.append(cbuf, off, len);
			checkFlush(false);
		}
	}

	/**
	 * Four letter command handler that echoes back the trace mask it was
	 * constructed with. Applying the mask is done by the caller, not here.
	 */
	private class SetTraceMaskCommand extends CommandThread {
		long trace;

		SetTraceMaskCommand(PrintWriter pw, long trace) {
			super(pw);
			this.trace = trace;
		}

		@Override
		public void commandRun() {
			pw.print(this.trace);
		}
	}

	/**
	 * Handler shared by the {@code statCmd} and {@code srvrCmd} four letter
	 * commands. Both print the version, server statistics and node count; only
	 * {@code statCmd} additionally lists the connected clients.
	 */
	private class StatCommand extends CommandThread {
		/** Code of the command that was actually received. */
		int len;

		public StatCommand(PrintWriter pw, int len) {
			super(pw);
			this.len = len;
		}

		@Override
		public void commandRun() {
			if (zkServer == null) {
				pw.println(ZK_NOT_SERVING);
				return;
			}

			pw.print("Zookeeper version: ");
			pw.println(Version.getFullVersion());
			if (zkServer instanceof ReadOnlyZooKeeperServer) {
				pw.println("READ-ONLY mode; serving only read-only clients");
			}

			if (len == statCmd) {
				LOG.info("Stat command output");
				pw.println("Clients:");
				// Copy the set under the lock and print from the copy, so the cnxns
				// lock is held only for the duration of the copy.
				final HashSet<ServerCnxn> snapshot;
				synchronized (factory.cnxns) {
					snapshot = new HashSet<ServerCnxn>(factory.cnxns);
				}
				for (ServerCnxn cnxn : snapshot) {
					cnxn.dumpConnectionInfo(pw, true);
					pw.println();
				}
				pw.println();
			}

			pw.print(zkServer.serverStats().toString());
			pw.print("Node count: ");
			pw.println(zkServer.getZKDatabase().getNodeCount());
		}
	}

	/**
	 * Four letter command handler that resets the server-wide statistics.
	 */
	private class StatResetCommand extends CommandThread {
		public StatResetCommand(PrintWriter pw) {
			super(pw);
		}

		@Override
		public void commandRun() {
			if (zkServer == null) {
				pw.println(ZK_NOT_SERVING);
				return;
			}
			zkServer.serverStats().reset();
			pw.println("Server stats reset.");
		}
	}

	/**
	 * Four letter command handler that prints the current text trace level
	 * reported by ZooTrace (no trailing newline).
	 */
	private class TraceMaskCommand extends CommandThread {
		TraceMaskCommand(PrintWriter pw) {
			super(pw);
		}

		@Override
		public void commandRun() {
			final long currentMask = ZooTrace.getTextTraceLevel();
			pw.print(currentMask);
		}
	}

	/**
	 * Handler shared by the three watch related four letter commands
	 * ({@code wchsCmd}, {@code wchpCmd} and {@code wchcCmd}).
	 */
	private class WatchCommand extends CommandThread {
		/** Code of the command that was actually received. */
		int len;

		public WatchCommand(PrintWriter pw, int len) {
			super(pw);
			this.len = len;
		}

		@Override
		public void commandRun() {
			if (zkServer == null) {
				pw.println(ZK_NOT_SERVING);
				return;
			}
			final DataTree tree = zkServer.getZKDatabase().getDataTree();
			if (len == wchsCmd) {
				tree.dumpWatchesSummary(pw);
			} else {
				// wchpCmd passes true, every other code passes false.
				// NOTE(review): the flag presumably selects grouping by path - confirm in DataTree.
				tree.dumpWatches(pw, len == wchpCmd);
			}
			pw.println();
		}
	}

	/** Four zero bytes written ahead of a response to reserve room for its length prefix. */
	private static final byte[] fourBytes = new byte[4];

	/** Reply printed by four letter commands when no ZooKeeperServer is attached. */
	private static final String ZK_NOT_SERVING = "This ZooKeeper instance is not currently serving requests";

	/**
	 * Payload of the packet currently being received; null while the length prefix
	 * of the next packet is still being read.
	 */
	ByteBuffer bb;

	/** Accumulates the 4-byte length prefix of the next packet. */
	ByteBuffer bbLen = ByteBuffer.allocate(4);

	/** Netty channel backing this connection. */
	Channel channel;

	/** Factory that created this connection and tracks it in its connection sets. */
	NettyServerCnxnFactory factory;

	/** Set once the connect request has been handed to the server. */
	boolean initialized;

	/** Logger shared by all connections; obtained once per class rather than once per instance. */
	static final Logger LOG = LoggerFactory.getLogger(NettyServerCnxn.class);

	/** Requests received but not yet answered; drives throttling. */
	AtomicLong outstandingCount = new AtomicLong();

	/**
	 * Not referenced within this class.
	 * NOTE(review): presumably used by the Netty handler to hold data received
	 * while the connection is throttled - confirm in NettyServerCnxnFactory.
	 */
	ChannelBuffer queuedBuffer;

	/** Id of the session bound to this connection. */
	long sessionId;

	/** Timeout of the session bound to this connection. */
	int sessionTimeout;

	/** True while receiving is disabled because too many requests are outstanding. */
	volatile boolean throttled;

	/**
	 * The ZooKeeperServer for this connection. May be null if the server is not
	 * currently serving requests (for example if the server is not an active quorum
	 * participant).
	 */
	private volatile ZooKeeperServer zkServer;

	NettyServerCnxn(Channel channel, ZooKeeperServer zks, NettyServerCnxnFactory factory) {
		this.channel = channel;
		this.zkServer = zks;
		this.factory = factory;
		if (this.factory.login != null) {
			this.zooKeeperSaslServer = new ZooKeeperSaslServer(factory.login);
		}
	}

	/**
	 * Checks whether the four bytes read as a packet length are actually a four
	 * letter admin command and, if so, runs the matching handler.
	 * <p>
	 * All commands are exactly four bytes, so they fit in the int that would
	 * otherwise hold the length. When a command is recognized the channel's
	 * interest ops are cleared, the packet is counted as received, and the handler
	 * writes its reply and closes the connection.
	 *
	 * @param channel channel the command arrived on
	 * @param message remainder of the received buffer; only the set trace mask
	 *                command reads from it (its 8-byte mask argument)
	 * @param len     the four bytes that were read as the packet length
	 * @return true if a four letter word was found and responded to, false otherwise
	 * @throws IOException if the set trace mask command arrives without its
	 *                     complete 8-byte mask argument
	 */
	private boolean checkFourLetterWord(final Channel channel, ChannelBuffer message, final int len)
			throws IOException {
		// We take advantage of the limited size of the length to look
		// for cmds. They are all 4-bytes which fits inside of an int
		String cmd = cmd2String.get(len);
		if (cmd == null) {
			return false;
		}
		channel.setInterestOps(0).awaitUninterruptibly();
		LOG.info("Processing " + cmd + " command from " + channel.getRemoteAddress());
		packetReceived();

		final PrintWriter pwriter = new PrintWriter(new BufferedWriter(new SendBufferWriter()));
		if (len == ruokCmd) {
			RuokCommand ruok = new RuokCommand(pwriter);
			ruok.start();
			return true;
		} else if (len == getTraceMaskCmd) {
			TraceMaskCommand tmask = new TraceMaskCommand(pwriter);
			tmask.start();
			return true;
		} else if (len == setTraceMaskCmd) {
			// The mask is a long, so the argument is 8 bytes, not 4.
			final int maskBytes = 8;
			if (message.readableBytes() < maskBytes) {
				// Fail with an IOException (the caller logs it and closes the connection)
				// instead of letting readBytes throw an unchecked exception.
				throw new IOException("Trace mask argument of " + cmd + " command is incomplete: expected "
						+ maskBytes + " bytes but only " + message.readableBytes() + " available");
			}
			ByteBuffer mask = ByteBuffer.allocate(maskBytes);
			message.readBytes(mask);

			// Flip the buffer we just filled (not the packet buffer bb, which is
			// still null at this point) so the long can be read back from it.
			mask.flip();
			long traceMask = mask.getLong();
			ZooTrace.setTextTraceLevel(traceMask);
			SetTraceMaskCommand setMask = new SetTraceMaskCommand(pwriter, traceMask);
			setMask.start();
			return true;
		} else if (len == enviCmd) {
			EnvCommand env = new EnvCommand(pwriter);
			env.start();
			return true;
		} else if (len == confCmd) {
			ConfCommand ccmd = new ConfCommand(pwriter);
			ccmd.start();
			return true;
		} else if (len == srstCmd) {
			StatResetCommand strst = new StatResetCommand(pwriter);
			strst.start();
			return true;
		} else if (len == crstCmd) {
			CnxnStatResetCommand crst = new CnxnStatResetCommand(pwriter);
			crst.start();
			return true;
		} else if (len == dumpCmd) {
			DumpCommand dump = new DumpCommand(pwriter);
			dump.start();
			return true;
		} else if (len == statCmd || len == srvrCmd) {
			StatCommand stat = new StatCommand(pwriter, len);
			stat.start();
			return true;
		} else if (len == consCmd) {
			ConsCommand cons = new ConsCommand(pwriter);
			cons.start();
			return true;
		} else if (len == wchpCmd || len == wchcCmd || len == wchsCmd) {
			WatchCommand wcmd = new WatchCommand(pwriter, len);
			wcmd.start();
			return true;
		} else if (len == mntrCmd) {
			MonitorCommand mntr = new MonitorCommand(pwriter);
			mntr.start();
			return true;
		} else if (len == isroCmd) {
			IsroCommand isro = new IsroCommand(pwriter);
			isro.start();
			return true;
		}
		return false;
	}

	/**
	 * Finishes a four letter command: flushes and closes the command's writer so
	 * all pending output is sent, then closes this connection. The connection is
	 * closed even when handling the writer fails.
	 *
	 * @param pwriter the writer used for the command's reply; may be null
	 */
	private void cleanupWriterSocket(PrintWriter pwriter) {
		try {
			if (pwriter == null) {
				// Nothing to flush; the finally block below still closes the connection.
				return;
			}
			pwriter.flush();
			pwriter.close();
		} catch (Exception writerFailure) {
			LOG.info("Error closing PrintWriter ", writerFailure);
		} finally {
			try {
				close();
			} catch (Exception closeFailure) {
				LOG.error("Error closing a command socket ", closeFailure);
			}
		}
	}

	/**
	 * Closes this connection: removes it from the factory's connection set and
	 * per-address map, closes the channel if it is still open and unregisters the
	 * connection from the factory. Calling it again once the connection has been
	 * removed from the set is a no-op.
	 */
	@Override
	public void close() {
		if (LOG.isDebugEnabled()) {
			LOG.debug("close called for sessionid:0x" + Long.toHexString(sessionId));
		}
		synchronized (factory.cnxns) {
			// if this is not in cnxns then it's already closed
			if (!factory.cnxns.remove(this)) {
				if (LOG.isDebugEnabled()) {
					LOG.debug("cnxns size:" + factory.cnxns.size());
				}
				return;
			}
			if (LOG.isDebugEnabled()) {
				LOG.debug("close in progress for sessionid:0x" + Long.toHexString(sessionId));
			}

			synchronized (factory.ipMap) {
				// The remote address may be unavailable and the map may hold no entry
				// for it. Skip the bookkeeping in that case rather than failing with a
				// NullPointerException, which would leave the channel open and the
				// connection registered even though it was already removed from cnxns.
				SocketAddress remote = channel.getRemoteAddress();
				if (remote instanceof InetSocketAddress) {
					Set<NettyServerCnxn> s = factory.ipMap.get(((InetSocketAddress) remote).getAddress());
					if (s != null) {
						s.remove(this);
					}
				}
			}

			if (channel.isOpen()) {
				channel.close();
			}
			factory.unregisterConnection(this);
		}
	}

	/**
	 * Marks the connection as throttled and makes the channel non-readable,
	 * waiting until the channel has applied the change.
	 */
	@Override
	public void disableRecv() {
		throttled = true;
		if (LOG.isDebugEnabled()) {
			LOG.debug("Throttling - disabling recv " + this);
		}
		final ChannelFuture readableOff = channel.setReadable(false);
		readableOff.awaitUninterruptibly();
	}

	/**
	 * Lifts throttling: clears the throttled flag and sends a ResumeMessageEvent
	 * upstream through the channel pipeline. Does nothing when the connection is
	 * not throttled.
	 */
	@Override
	public void enableRecv() {
		if (!throttled) {
			return;
		}
		throttled = false;
		if (LOG.isDebugEnabled()) {
			LOG.debug("Sending unthrottle event " + this);
		}
		channel.getPipeline().sendUpstream(new ResumeMessageEvent(channel));
	}

	/** @return the interest ops currently reported by the underlying channel */
	@Override
	public int getInterestOps() {
		return this.channel.getInterestOps();
	}

	/** @return the current value of the outstanding request counter */
	@Override
	public long getOutstandingRequests() {
		return outstandingCount.get();
	}

	/** @return the channel's remote address, cast to an InetSocketAddress */
	@Override
	public InetSocketAddress getRemoteSocketAddress() {
		final SocketAddress remote = channel.getRemoteAddress();
		return (InetSocketAddress) remote;
	}

	/** @return the id of the session bound to this connection */
	@Override
	public long getSessionId() {
		return this.sessionId;
	}

	/** @return the timeout of the session bound to this connection */
	@Override
	public int getSessionTimeout() {
		return this.sessionTimeout;
	}

	/**
	 * Delivers a watch event to the client as a "notification" response. If
	 * sending fails with an IOException the connection is closed.
	 *
	 * @param event the event to deliver
	 */
	@Override
	public void process(WatchedEvent event) {
		final ReplyHeader header = new ReplyHeader(-1, -1L, 0);
		if (LOG.isTraceEnabled()) {
			final String detail = "Deliver event " + event + " to 0x" + Long.toHexString(this.sessionId)
					+ " through " + this;
			ZooTrace.logTraceMessage(LOG, ZooTrace.EVENT_DELIVERY_TRACE_MASK, detail);
		}

		// WatchedEvent itself is not a wire type; use its wrapper for sending.
		final WatcherEvent wireEvent = event.getWrapper();

		try {
			sendResponse(header, wireEvent, "notification");
		} catch (IOException sendFailure) {
			if (LOG.isDebugEnabled()) {
				LOG.debug("Problem sending to " + getRemoteSocketAddress(), sendFailure);
			}
			close();
		}
	}

	/**
	 * Consumes the bytes of an incoming Netty buffer and reassembles them into
	 * length-prefixed packets.
	 * <p>
	 * The method alternates between two states. While {@code bb} is null it fills
	 * the 4-byte {@code bbLen} prefix; once the prefix is complete it allocates
	 * {@code bb} with that length. While {@code bb} is non-null it fills it, and a
	 * completed {@code bb} is handed to the server as the connect request (first
	 * packet) or as a regular packet. On a connection that is not yet initialized,
	 * the prefix is first checked against the four letter commands. The loop ends
	 * when the message is drained or the connection becomes throttled. An
	 * IOException raised along the way is logged and the connection is closed.
	 *
	 * @param message buffer holding the bytes received from the channel
	 */
	public void receiveMessage(ChannelBuffer message) {
		try {
			while (message.readable() && !throttled) {
				if (bb != null) {
					// State: reading the payload of a packet whose length is already known.
					if (LOG.isTraceEnabled()) {
						LOG.trace(
								"message readable " + message.readableBytes() + " bb len " + bb.remaining() + " " + bb);
						ByteBuffer dat = bb.duplicate();
						dat.flip();
						LOG.trace(Long.toHexString(sessionId) + " bb 0x"
								+ ChannelBuffers.hexDump(ChannelBuffers.copiedBuffer(dat)));
					}

					// The message holds fewer bytes than bb still needs: temporarily lower
					// bb's limit so readBytes transfers only what is available.
					if (bb.remaining() > message.readableBytes()) {
						int newLimit = bb.position() + message.readableBytes();
						bb.limit(newLimit);
					}
					message.readBytes(bb);
					// Restore the full limit so remaining() reflects what is still missing.
					bb.limit(bb.capacity());

					if (LOG.isTraceEnabled()) {
						LOG.trace("after readBytes message readable " + message.readableBytes() + " bb len "
								+ bb.remaining() + " " + bb);
						ByteBuffer dat = bb.duplicate();
						dat.flip();
						LOG.trace("after readbytes " + Long.toHexString(sessionId) + " bb 0x"
								+ ChannelBuffers.hexDump(ChannelBuffers.copiedBuffer(dat)));
					}
					if (bb.remaining() == 0) {
						// Payload complete: hand the packet to the server.
						packetReceived();
						bb.flip();

						// Read the volatile field once so the null check and the calls below
						// use the same server instance.
						ZooKeeperServer zks = this.zkServer;
						if (zks == null) {
							throw new IOException("ZK down");
						}
						if (initialized) {
							zks.processPacket(this, bb);

							// Count the request as outstanding and stop reading if the server
							// says this connection should be throttled.
							if (zks.shouldThrottle(outstandingCount.incrementAndGet())) {
								disableRecv();
							}
						} else {
							// The first packet on a connection is the connect request.
							LOG.debug("got conn req request from " + getRemoteSocketAddress());
							zks.processConnectRequest(this, bb);
							initialized = true;
						}
						// Back to the "reading length prefix" state.
						bb = null;
					}
				} else {
					// State: reading the 4-byte length prefix of the next packet.
					if (LOG.isTraceEnabled()) {
						LOG.trace("message readable " + message.readableBytes() + " bblenrem " + bbLen.remaining());
						ByteBuffer dat = bbLen.duplicate();
						dat.flip();
						LOG.trace(Long.toHexString(sessionId) + " bbLen 0x"
								+ ChannelBuffers.hexDump(ChannelBuffers.copiedBuffer(dat)));
					}

					// Same partial-read handling as for bb: cap the limit to what is available.
					if (message.readableBytes() < bbLen.remaining()) {
						bbLen.limit(bbLen.position() + message.readableBytes());
					}
					message.readBytes(bbLen);
					bbLen.limit(bbLen.capacity());
					if (bbLen.remaining() == 0) {
						// Prefix complete: decode it.
						bbLen.flip();

						if (LOG.isTraceEnabled()) {
							LOG.trace(Long.toHexString(sessionId) + " bbLen 0x"
									+ ChannelBuffers.hexDump(ChannelBuffers.copiedBuffer(bbLen)));
						}
						int len = bbLen.getInt();
						if (LOG.isTraceEnabled()) {
							LOG.trace(Long.toHexString(sessionId) + " bbLen len is " + len);
						}

						// Reset the prefix buffer for the next packet.
						bbLen.clear();
						if (!initialized) {
							// Before the connect request, the four bytes may be a four letter
							// command instead of a length; if so it has been handled and
							// there is nothing more to read.
							if (checkFourLetterWord(channel, message, len)) {
								return;
							}
						}
						// Reject negative or oversized lengths before allocating.
						if (len < 0 || len > BinaryInputArchive.maxBuffer) {
							throw new IOException("Len error " + len);
						}
						bb = ByteBuffer.allocate(len);
					}
				}
			}
		} catch (IOException e) {
			LOG.warn("Closing connection to " + getRemoteSocketAddress(), e);
			close();
		}
	}

	/**
	 * Writes a buffer to the channel and counts it as a sent packet. The
	 * {@code ServerCnxnFactory.closeConn} sentinel is not written; it closes the
	 * channel instead.
	 *
	 * @param sendBuffer the data to send, or the close sentinel
	 */
	@Override
	public void sendBuffer(ByteBuffer sendBuffer) {
		final boolean closeRequested = sendBuffer == ServerCnxnFactory.closeConn;
		if (closeRequested) {
			channel.close();
		} else {
			channel.write(wrappedBuffer(sendBuffer));
			packetSent();
		}
	}

	/**
	 * Requests that the connection be closed by passing the
	 * {@code ServerCnxnFactory.closeConn} sentinel to {@code sendBuffer}.
	 */
	public void sendCloseSession() {
		this.sendBuffer(ServerCnxnFactory.closeConn);
	}

	/**
	 * Serializes a reply (4-byte length prefix, header, optional record) and sends
	 * it to the client. Does nothing if the channel is no longer open.
	 * <p>
	 * For replies with a positive xid the outstanding request counter is
	 * decremented, and receiving is re-enabled when the server no longer asks for
	 * this connection to be throttled.
	 *
	 * @param h   reply header
	 * @param r   record to send after the header; may be null
	 * @param tag tag under which the record is serialized
	 * @throws IOException declared by the overridden method; serialization failures
	 *                     are logged here rather than propagated
	 */
	@Override
	public void sendResponse(ReplyHeader h, Record r, String tag) throws IOException {
		if (!channel.isOpen()) {
			return;
		}
		ByteArrayOutputStream baos = new ByteArrayOutputStream();
		// Make space for length
		BinaryOutputArchive bos = BinaryOutputArchive.getArchive(baos);
		try {
			baos.write(fourBytes);
			bos.writeRecord(h, "header");
			if (r != null) {
				bos.writeRecord(r, tag);
			}
			baos.close();
		} catch (IOException e) {
			// Log the exception itself so the cause of the failure is not lost.
			LOG.error("Error serializing response", e);
		}
		byte b[] = baos.toByteArray();
		// Named differently from the field bb (the receive buffer) to avoid shadowing it.
		ByteBuffer response = ByteBuffer.wrap(b);
		// Overwrite the four placeholder bytes with the payload length, then rewind for sending.
		response.putInt(b.length - 4).rewind();
		sendBuffer(response);
		if (h.getXid() > 0) {
			// zks cannot be null otherwise we would not have gotten here!
			if (!zkServer.shouldThrottle(outstandingCount.decrementAndGet())) {
				enableRecv();
			}
		}
	}

	/**
	 * @return the statistics of the attached server, or null when no server is
	 *         attached
	 */
	@Override
	protected ServerStats serverStats() {
		return zkServer == null ? null : zkServer.serverStats();
	}

	/** @param sessionId id of the session to bind to this connection */
	@Override
	public void setSessionId(long sessionId) {
		this.sessionId = sessionId;
	}

	/** @param sessionTimeout timeout of the session bound to this connection */
	@Override
	public void setSessionTimeout(int sessionTimeout) {
		this.sessionTimeout = sessionTimeout;
	}

}
